# -*- coding: utf-8 -*-
"""
@author:XuMing(xuming624@qq.com)
@description:
part of the code from https://github.com/phidatahq/phidata
"""
from __future__ import annotations
import collections.abc
import inspect

from os import getenv
from uuid import uuid4
from types import GeneratorType
from typing import Any, Optional, Dict, Callable
from pydantic import BaseModel, Field, ConfigDict, field_validator, PrivateAttr

from agentica.utils.log import logger, set_log_level_to_debug
from agentica.agent import Agent
from agentica.run_response import RunResponse
from agentica.memory import WorkflowMemory, WorkflowRun
from agentica.storage.workflow.base import WorkflowStorage
from agentica.utils.misc import merge_dictionaries
from agentica.workflow_session import WorkflowSession


class Workflow(BaseModel):
    """Base class for stateful, optionally storage-backed workflows.

    Subclasses implement ``run(...)``. When a subclass overrides ``run``, ``__init__``
    captures it as ``_subclass_run`` and replaces the instance attribute ``run`` with
    ``run_workflow``. ``run_workflow`` stamps run/session/workflow IDs onto the
    ``RunResponse`` objects produced by the subclass, records the run in ``memory``
    and persists the session through ``storage`` when one is configured.
    """

    # -*- Workflow settings
    # Workflow name (defaults to the subclass name in __init__)
    name: Optional[str] = None
    # Workflow description
    description: Optional[str] = None
    # Workflow UUID (autogenerated by set_workflow_id if not set)
    workflow_id: Optional[str] = Field(None, validate_default=True)
    # Metadata associated with this workflow
    workflow_data: Optional[Dict[str, Any]] = None

    # -*- User settings
    # ID of the user interacting with this workflow
    user_id: Optional[str] = None
    # Metadata associated with the user interacting with this workflow
    user_data: Optional[Dict[str, Any]] = None

    # -*- Session settings
    # Session UUID (autogenerated by set_session_id if not set)
    session_id: Optional[str] = Field(None, validate_default=True)
    # Session name
    session_name: Optional[str] = None
    # Session state, persisted under session_data["session_state"]
    session_state: Dict[str, Any] = Field(default_factory=dict)

    # -*- Workflow Memory
    # default_factory gives every instance its own WorkflowMemory instead of a
    # shared default object.
    memory: WorkflowMemory = Field(default_factory=WorkflowMemory)

    # -*- Workflow Storage
    storage: Optional[WorkflowStorage] = None
    # WorkflowSession from the database: DO NOT SET MANUALLY
    _workflow_session: Optional[WorkflowSession] = None

    # debug_mode=True enables debug logs (see set_log_level)
    debug_mode: bool = Field(False, validate_default=True)
    # Monitoring flag, read from the PHI_MONITORING env var (not consumed within this class)
    monitoring: bool = getenv("PHI_MONITORING", "false").lower() == "true"
    # Telemetry flag, read from the PHI_TELEMETRY env var (not consumed within this class)
    telemetry: bool = getenv("PHI_TELEMETRY", "true").lower() == "true"

    # DO NOT SET THE FOLLOWING FIELDS MANUALLY
    # Run ID: DO NOT SET MANUALLY
    run_id: Optional[str] = None
    # Input to the Workflow run: DO NOT SET MANUALLY
    run_input: Optional[Dict[str, Any]] = None
    # Response from the Workflow run: DO NOT SET MANUALLY
    run_response: RunResponse = Field(default_factory=RunResponse)
    # Metadata associated with this session: DO NOT SET MANUALLY
    session_data: Optional[Dict[str, Any]] = None

    # The run function provided by the subclass
    _subclass_run: Callable = PrivateAttr()
    # Parameters of the run function
    _run_parameters: Dict[str, Any] = PrivateAttr()
    # Return type of the run function
    _run_return_type: Optional[str] = PrivateAttr()

    model_config = ConfigDict(arbitrary_types_allowed=True, populate_by_name=True)

    @field_validator("workflow_id", mode="before")
    @classmethod
    def set_workflow_id(cls, v: Optional[str]) -> str:
        """Return ``v``, or a fresh UUID4 string when ``v`` is falsy."""
        workflow_id = v or str(uuid4())
        logger.debug(f"*********** Workflow ID: {workflow_id} ***********")
        return workflow_id

    @field_validator("session_id", mode="before")
    @classmethod
    def set_session_id(cls, v: Optional[str]) -> str:
        """Return ``v``, or a fresh UUID4 string when ``v`` is falsy."""
        session_id = v or str(uuid4())
        logger.debug(f"*********** Workflow Session ID: {session_id} ***********")
        return session_id

    @field_validator("debug_mode", mode="before")
    @classmethod
    def set_log_level(cls, v: bool) -> bool:
        """Switch logging to debug level if ``v`` is truthy or PHI_DEBUG=true.

        The field value itself is returned unchanged.
        """
        if v or getenv("PHI_DEBUG", "false").lower() == "true":
            set_log_level_to_debug()
            logger.debug("Debug logs enabled")
        return v

    def get_workflow_data(self) -> Dict[str, Any]:
        """Return workflow metadata with ``name`` added when set.

        Note: when ``workflow_data`` is a non-empty dict it is updated in place.
        """
        workflow_data = self.workflow_data or {}
        if self.name is not None:
            workflow_data["name"] = self.name
        return workflow_data

    def get_session_data(self) -> Dict[str, Any]:
        """Return session metadata with ``session_name``/``session_state`` added when set.

        Note: when ``session_data`` is a non-empty dict it is updated in place.
        """
        session_data = self.session_data or {}
        if self.session_name is not None:
            session_data["session_name"] = self.session_name
        if len(self.session_state) > 0:
            session_data["session_state"] = self.session_state
        return session_data

    def get_workflow_session(self) -> WorkflowSession:
        """Get a WorkflowSession object, which can be saved to the database"""

        return WorkflowSession(
            session_id=self.session_id,
            workflow_id=self.workflow_id,
            user_id=self.user_id,
            memory=self.memory.to_dict(),
            workflow_data=self.get_workflow_data(),
            user_data=self.get_user_data() if False else self.user_data,
            session_data=self.get_session_data(),
        )

    def from_workflow_session(self, session: WorkflowSession):
        """Load the existing Workflow from a WorkflowSession (from the database).

        IDs and names are only taken from ``session`` when not already set locally.
        For workflow_data, user_data, session_data and session_state, the locally-set
        values are merged into the database values, and the merged dict then becomes
        the local value. ``memory.runs`` is rebuilt from ``session.memory["runs"]``
        on a best-effort basis: failures are logged, not raised.
        """

        # Get the session_id, workflow_id and user_id from the database
        if self.session_id is None and session.session_id is not None:
            self.session_id = session.session_id
        if self.workflow_id is None and session.workflow_id is not None:
            self.workflow_id = session.workflow_id
        if self.user_id is None and session.user_id is not None:
            self.user_id = session.user_id

        # Read workflow_data from the database
        if session.workflow_data is not None:
            # Get name from database and update the workflow name if not set
            if self.name is None and "name" in session.workflow_data:
                self.name = session.workflow_data.get("name")

            # If workflow_data is set in the workflow, update the database workflow_data with the workflow's workflow_data
            if self.workflow_data is not None:
                # Updates workflow_session.workflow_data in place
                merge_dictionaries(session.workflow_data, self.workflow_data)
            self.workflow_data = session.workflow_data

        # Read user_data from the database
        if session.user_data is not None:
            # If user_data is set in the workflow, update the database user_data with the workflow's user_data
            if self.user_data is not None:
                # Updates workflow_session.user_data in place
                merge_dictionaries(session.user_data, self.user_data)
            self.user_data = session.user_data

        # Read session_data from the database
        if session.session_data is not None:
            # Get the session_name from database and update the current session_name if not set
            if self.session_name is None and "session_name" in session.session_data:
                self.session_name = session.session_data.get("session_name")

            # Get the session_state from database and update the current session_state
            if "session_state" in session.session_data:
                session_state_from_db = session.session_data.get("session_state")
                if (
                        session_state_from_db is not None
                        and isinstance(session_state_from_db, dict)
                        and len(session_state_from_db) > 0
                ):
                    # If the session_state is already set, merge the session_state from the database with the current session_state
                    if len(self.session_state) > 0:
                        # This updates session_state_from_db
                        merge_dictionaries(session_state_from_db, self.session_state)
                    # Update the current session_state
                    self.session_state = session_state_from_db

            # If session_data is set in the workflow, update the database session_data with the workflow's session_data
            if self.session_data is not None:
                # Updates workflow_session.session_data in place
                merge_dictionaries(session.session_data, self.session_data)
            self.session_data = session.session_data

        # Read memory from the database (best effort: a malformed record only logs a warning)
        if session.memory is not None:
            try:
                if "runs" in session.memory:
                    self.memory.runs = [WorkflowRun(**m) for m in session.memory["runs"]]
            except Exception as e:
                logger.warning(f"Failed to load WorkflowMemory: {e}")
        logger.debug(f"-*- WorkflowSession loaded: {session.session_id}")

    def read_from_storage(self) -> Optional[WorkflowSession]:
        """Load the WorkflowSession from storage.

        Returns:
            Optional[WorkflowSession]: The loaded WorkflowSession or None if not found.
        """
        if self.storage is not None and self.session_id is not None:
            self._workflow_session = self.storage.read(session_id=self.session_id)
            if self._workflow_session is not None:
                self.from_workflow_session(session=self._workflow_session)
        return self._workflow_session

    def write_to_storage(self) -> Optional[WorkflowSession]:
        """Save the WorkflowSession to storage

        Returns:
            Optional[WorkflowSession]: The saved WorkflowSession or None if not saved.
        """
        if self.storage is not None:
            self._workflow_session = self.storage.upsert(session=self.get_workflow_session())
        return self._workflow_session

    def load_session(self, force: bool = False) -> Optional[str]:
        """Load an existing session from the database and return the session_id.
        If a session does not exist, create a new session.

        - If a session exists in the database, load the session.
        - If a session does not exist in the database, create a new session.

        Args:
            force: Re-read from storage even if a matching session is already loaded.

        Raises:
            Exception: If storage is set but creating a new session returns nothing.
        """
        # If a workflow_session is already loaded, return the session_id from the workflow_session
        # if session_id matches the session_id from the workflow_session
        if self._workflow_session is not None and not force:
            if self.session_id is not None and self._workflow_session.session_id == self.session_id:
                return self._workflow_session.session_id

        # Load an existing session or create a new session
        if self.storage is not None:
            # Load existing session if session_id is provided
            logger.debug(f"Reading WorkflowSession: {self.session_id}")
            self.read_from_storage()

            # Create a new session if it does not exist
            if self._workflow_session is None:
                logger.debug("-*- Creating new WorkflowSession")
                # write_to_storage() will create a new WorkflowSession
                # and populate self._workflow_session with the new session
                self.write_to_storage()
                if self._workflow_session is None:
                    raise Exception("Failed to create new WorkflowSession in storage")
                logger.debug(f"-*- Created WorkflowSession: {self._workflow_session.session_id}")
                self.log_workflow_session()
        return self.session_id

    def run(self, *args: Any, **kwargs: Any):
        """Placeholder to be overridden by subclasses; logs an error and returns None."""
        logger.error(f"{self.__class__.__name__}.run() method not implemented.")
        return

    def run_workflow(self, *args: Any, **kwargs: Any):
        """Run the subclass ``run()`` with ID stamping, memory and storage bookkeeping.

        Returns:
            - a generator re-yielding the subclass's items when ``run()`` returns an iterator;
            - the (ID-stamped) RunResponse when ``run()`` returns a RunResponse;
            - None (with a warning) for any other return value.
        """
        self.run_id = str(uuid4())
        self.run_input = {"args": args, "kwargs": kwargs}
        self.run_response = RunResponse(run_id=self.run_id, session_id=self.session_id, workflow_id=self.workflow_id)
        self.read_from_storage()
        result = self._subclass_run(*args, **kwargs)

        # Case 1: The run method returns an Iterator[RunResponse]
        if isinstance(result, (GeneratorType, collections.abc.Iterator)):
            # Initialize the run_response content
            self.run_response.content = ""

            def result_generator():
                for item in result:
                    if isinstance(item, RunResponse):
                        # Update the run_id, session_id and workflow_id of the RunResponse
                        item.run_id = self.run_id
                        item.session_id = self.session_id
                        item.workflow_id = self.workflow_id

                        # Accumulate streamed string content into the run_response
                        if item.content is not None and isinstance(item.content, str):
                            self.run_response.content += item.content
                    else:
                        logger.warning(f"Workflow.run() should only yield RunResponse objects, got: {type(item)}")
                    yield item

                # Reached only once the caller has consumed the whole stream:
                # memory and storage are not updated for partially consumed runs.
                self.memory.add_run(WorkflowRun(input=self.run_input, response=self.run_response))
                self.write_to_storage()
            return result_generator()
        # Case 2: The run method returns a RunResponse
        elif isinstance(result, RunResponse):
            # Update the result with the run_id, session_id and workflow_id of the workflow run
            result.run_id = self.run_id
            result.session_id = self.session_id
            result.workflow_id = self.workflow_id

            # Update the run_response with the content from the result
            if result.content is not None and isinstance(result.content, str):
                self.run_response.content = result.content

            # Add the run to the memory
            self.memory.add_run(WorkflowRun(input=self.run_input, response=self.run_response))
            # Write this run to the database
            self.write_to_storage()
            return result
        else:
            logger.warning(f"Workflow.run() should only return RunResponse objects, got: {type(result)}")
            return None

    @staticmethod
    def _annotation_to_str(annotation: Any) -> Optional[str]:
        """Render a parameter annotation as a plain string (None if unannotated).

        Objects with ``__name__`` use it. Otherwise ``str(annotation)`` is used, and an
        outer ``typing.Optional[...]`` wrapper is removed. Only that wrapper is stripped,
        so nested brackets such as ``typing.List[str]`` stay well-formed.
        """
        if annotation is inspect.Parameter.empty:
            return None
        if hasattr(annotation, "__name__"):
            return annotation.__name__
        text = str(annotation)
        optional_prefix = "typing.Optional["
        if text.startswith(optional_prefix) and text.endswith("]"):
            return text[len(optional_prefix):-1]
        return text

    def __init__(self, **data):
        """Validate fields, default ``name`` to the class name and wire up ``run``.

        If the subclass overrides ``run``, its signature is recorded in
        ``_run_parameters``/``_run_return_type``. The instance's ``run`` is then
        replaced by ``run_workflow``, so callers get the bookkeeping wrapper.
        """
        super().__init__(**data)
        self.name = self.name or self.__class__.__name__
        # Check if 'run' is provided by the subclass
        if self.__class__.run is not Workflow.run:
            # Store the original run method bound to the instance
            self._subclass_run = self.__class__.run.__get__(self)
            # Get the parameters of the run method
            sig = inspect.signature(self.__class__.run)
            # Convert parameters to a serializable format
            self._run_parameters = {
                name: {
                    "name": name,
                    "default": param.default if param.default is not inspect.Parameter.empty else None,
                    "annotation": self._annotation_to_str(param.annotation),
                    "required": param.default is inspect.Parameter.empty,
                }
                for name, param in sig.parameters.items()
                if name != "self"
            }
            # Determine the return type of the run method
            return_annotation = sig.return_annotation
            if return_annotation is inspect.Signature.empty:
                self._run_return_type = None
            elif hasattr(return_annotation, "__name__"):
                self._run_return_type = return_annotation.__name__
            else:
                self._run_return_type = str(return_annotation)
            # Replace the instance's run method with run_workflow
            object.__setattr__(self, "run", self.run_workflow.__get__(self))
        else:
            # This will log an error when called
            self._subclass_run = self.run
            self._run_parameters = {}
            self._run_return_type = None

    def model_post_init(self, __context: Any) -> None:
        """Propagate this workflow's session_id to every Agent-valued field."""
        super().model_post_init(__context)
        # `model_fields` is the pydantic v2 API; `__fields__` is deprecated in v2.
        for field_name in type(self).model_fields:
            value = getattr(self, field_name)
            if isinstance(value, Agent):
                value.session_id = self.session_id

    def log_workflow_session(self):
        """Emit a debug log line for the current session_id."""
        logger.debug(f"*********** Logging WorkflowSession: {self.session_id} ***********")

    def rename_session(self, session_id: str, name: str):
        """Set ``session_data["session_name"]`` for a stored session and upsert it.

        Raises:
            ValueError: If no storage is configured.
            Exception: If the session does not exist in storage.
        """
        if self.storage is None:
            raise ValueError("Storage is not set")
        workflow_session = self.storage.read(session_id)
        if workflow_session is None:
            raise Exception(f"WorkflowSession not found: {session_id}")
        if workflow_session.session_data is not None:
            workflow_session.session_data["session_name"] = name
        else:
            workflow_session.session_data = {"session_name": name}
        self.storage.upsert(workflow_session)

    def delete_session(self, session_id: str):
        """Delete a stored session via ``storage.delete_session``.

        Raises:
            ValueError: If no storage is configured.
        """
        if self.storage is None:
            raise ValueError("Storage is not set")
        self.storage.delete_session(session_id)

    def deep_copy(self, *, update: Optional[Dict[str, Any]] = None) -> "Workflow":
        """Create and return a deep copy of this Workflow, optionally updating fields.

        Only explicitly-set fields (``model_fields_set``) with non-None values are copied.
        Agent-valued fields use ``Agent.deep_copy``; all others go through ``_deep_copy_field``.

        Args:
            update (Optional[Dict[str, Any]]): Optional dictionary of fields for the new Workflow.

        Returns:
            Workflow: A new Workflow instance.
        """
        # Extract the fields to set for the new Workflow
        fields_for_new_workflow = {}

        for field_name in self.model_fields_set:
            field_value = getattr(self, field_name)
            if field_value is not None:
                if isinstance(field_value, Agent):
                    fields_for_new_workflow[field_name] = field_value.deep_copy()
                else:
                    fields_for_new_workflow[field_name] = self._deep_copy_field(field_name, field_value)

        # Update fields if provided
        if update:
            fields_for_new_workflow.update(update)

        # Create a new Workflow
        new_workflow = self.__class__(**fields_for_new_workflow)
        logger.debug(
            f"Created new Workflow: workflow_id: {new_workflow.workflow_id} | session_id: {new_workflow.session_id}"
        )
        return new_workflow

    def _deep_copy_field(self, field_name: str, field_value: Any) -> Any:
        """Helper method to deep copy a field based on its type.

        Falls back from deep copy to shallow copy to the original object, logging a
        warning at each failed step, so copying never raises.
        """
        from copy import copy, deepcopy

        # For memory, use its deep_copy method
        if field_name == "memory":
            return field_value.deep_copy()

        # For compound types, attempt a deep copy
        if isinstance(field_value, (list, dict, set, WorkflowStorage)):
            try:
                return deepcopy(field_value)
            except Exception as e:
                logger.warning(f"Failed to deepcopy field: {field_name} - {e}")
                try:
                    return copy(field_value)
                except Exception as e:
                    logger.warning(f"Failed to copy field: {field_name} - {e}")
                    return field_value

        # For pydantic models, attempt a deep copy
        if isinstance(field_value, BaseModel):
            try:
                return field_value.model_copy(deep=True)
            except Exception as e:
                logger.warning(f"Failed to deepcopy field: {field_name} - {e}")
                try:
                    return field_value.model_copy(deep=False)
                except Exception as e:
                    logger.warning(f"Failed to copy field: {field_name} - {e}")
                    return field_value

        # For other types, return as is
        return field_value
